
Conversation

@lkhphuc (Contributor) commented on Aug 21, 2025:

First PR to onboard modern VLM training to torchtitan.

Features:

  • Native Aspect Ratio: not limited to square crops.
  • Native Resolution: images in a batch can have different sizes; no more image tiles and thumbnails.
  • Native Interleaved Data: training samples can contain a variable number of images, interleaved with text at arbitrary positions. You can train more than just a captioning model.

Design

Distributed training usually does not play nicely with inputs of varying shapes. To handle a varying number of images and image sizes, we require two additional hyperparameters, the number of images per batch N and the maximum image patch sequence length L, and then pad the actual image patches to this fixed size.

  • After tok_embedding, we obtain text tokens of shape BxS.
  • After the encoder, we obtain visual tokens of shape NxL.
  • We extract only the valid visual tokens,
  • then scatter those tokens to their actual positions in the LLM input tokens (see the sketch below).
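
For concreteness, here is a minimal sketch of the extract-and-scatter step above. All names (h_text, h_vis, valid_vis, image_token_id) are hypothetical and only assume the shapes described in this PR, not its actual API:

```python
# Illustrative sketch only; assumes the B x S text tokens contain one <|image|>
# placeholder per valid image patch, as the dataloader guarantees by construction.
import torch

def scatter_visual_tokens(
    h_text: torch.Tensor,     # [B, S, D] text token embeddings after tok_embedding
    h_vis: torch.Tensor,      # [N, L, D] visual tokens after the encoder (padded)
    valid_vis: torch.Tensor,  # [N, L] bool mask, False for padded image patches
    input_ids: torch.Tensor,  # [B, S] token ids
    image_token_id: int,      # id of the image placeholder token
) -> torch.Tensor:
    # Keep only the real visual tokens.
    vis_flat = h_vis[valid_vis]             # [num_valid, D]
    # Positions in the LLM input reserved for image patches.
    slots = input_ids == image_token_id     # [B, S]
    # Scatter visual tokens into those positions.
    out = h_text.clone()
    out[slots] = vis_flat.to(out.dtype)
    return out
```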

This requires the dataloader to handle the following aspects:

  • Interleave the precise number of image tokens into the input tokens, based on the encoder's patch size and the input images' sizes.
  • Convert images/videos to a 1D sequence of patches:
    • rearrange(pixels, 'n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)', pt=temporal_ps, ph=patch_size, pw=patch_size)
    • Pad all image patch sequences to a fixed length and return pixel_values.shape == [N, L, D].
  • Return grid_thw.shape == [N, L, 3] to keep track of the location indices of each patch in its image. Padding patches can be tracked in the same tensor with the value -1. (A minimal patchify-and-pad sketch follows this list.)
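
For illustration, a minimal patchify-and-pad sketch following the layout described above. The helper name, the temporal_ps/patch_size argument names, and the assumption that the number of patches fits within the fixed length are mine, not the PR's code:

```python
# Sketch: patchify [n, T, H, W, C] pixels, then pad to fixed [N, L, D] with
# grid coordinates [N, L, 3]; padded positions get grid value -1.
# Truncation of over-long sequences is omitted for brevity.
import torch
from einops import rearrange

def patchify_and_pad(pixels: torch.Tensor, patch_size: int, temporal_ps: int, max_len: int):
    patches = rearrange(
        pixels,
        "n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)",
        pt=temporal_ps, ph=patch_size, pw=patch_size,
    )                                                   # [n, num_patches, D]
    n, num_patches, dim = patches.shape
    # (t, h, w) coordinates for each patch, matching the (t h w) flattening order above.
    t = pixels.shape[1] // temporal_ps
    h = pixels.shape[2] // patch_size
    w = pixels.shape[3] // patch_size
    coords = torch.stack(
        torch.meshgrid(torch.arange(t), torch.arange(h), torch.arange(w), indexing="ij"),
        dim=-1,
    ).reshape(1, -1, 3).expand(n, -1, -1)               # [n, num_patches, 3]
    # Pad the patch dimension to the fixed length L.
    pixel_values = torch.zeros(n, max_len, dim, dtype=patches.dtype)
    grid_thw = torch.full((n, max_len, 3), -1, dtype=torch.long)
    pixel_values[:, :num_patches] = patches
    grid_thw[:, :num_patches] = coords
    return pixel_values, grid_thw
```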

This results in a very simple and general interface to train modern VLMs with interleaved data and native resolution & aspect ratio:

  • Depending on the data mixture, we can set the dataloader's hyperparameters N and L to have minimal empty image padding (in the batch dimension).
  • Use modern PyTorch features (Flex Attention, compile, etc.) for efficient handling of a different attention mask per image (padding in the sequence dimension); see the sketch after this list.
  • Interfaces nicely with TP, PP, etc.
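
As a rough sketch of the second point, one way to express the per-image padding mask with FlexAttention (PyTorch >= 2.5). Function and variable names here are illustrative assumptions, not the PR's implementation:

```python
# Sketch: restrict attention in the vision encoder to valid (non-padding) patches,
# using the grid_thw convention above where padded patches are marked with -1.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def encoder_attention(q, k, v, grid_thw):
    # q, k, v: [N, num_heads, L, head_dim]; grid_thw: [N, L, 3]
    valid = grid_thw[..., 0] >= 0              # [N, L] True for real patches

    def mask_mod(b, h, q_idx, kv_idx):
        # Attend only among valid patches of the same image b.
        return valid[b, q_idx] & valid[b, kv_idx]

    n, _, L, _ = q.shape
    block_mask = create_block_mask(mask_mod, B=n, H=None, Q_LEN=L, KV_LEN=L, device=q.device)
    return flex_attention(q, k, v, block_mask=block_mask)
```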

In this PR

  • Minimal interleaved Obelics dataloader with native resolution and aspect ratio.
    • The dataloader is currently very slow, as it needs to download images from the internet every time you run. (The same is true for the current implementation in the multimodal experiment.)
  • Siglip2 model code, mostly based on HF.
  • VLM model code, called Llama3Siglip2, connecting the vision encoder and language decoder.
  • Minimal infra code for the debug model to run.

Todo:

  • Add support for captioning HF datasets that store images inside the dataset (CC12M, like the Flux experiment?) so loading is not super slow.
  • Flex Attention for the encoder.
  • Modify the Llama3 tokenizer to add special tokens.
  • Script to combine Siglip2 + Llama3 weights and load them.
  • Test Siglip2 encoder correctness.
  • Multimodal CE loss to correct for image token bias.
  • All the parallelisms: DP, CP, TP, PP.

@wwwjn (Contributor) left a comment:

Thanks for making this great PR! I learned a lot from it personally. However, I feel the data preprocessing part in mm_collator_nld.py is a little hard to follow and read.

Image preprocessing mainly happens in mm_collator_nld.py, and the collator functions contain the following steps for images:

  1. Patchify
  2. Generate grids with coordinates
  3. Pad / truncate
  4. Assemble into batched outputs

Text preprocessing is mainly handled in mm_dataset.py, which also contains several steps, e.g. padding with <image> tokens, tokenization, and masking out <image> tokens in the labels.

I was wondering whether we can further split the image and text preprocessing functions into smaller code pieces, add tensor shape hints, or even add examples like experiments/multimodal. This would improve readability and make debugging easier.

The VLM modeling parts LGTM, they're clearly structured!

Comment on lines 242 to 245
# Normalize with OpenAI CLIP mean/std
mean = np.array([0.48145466, 0.4578275, 0.40821073])
std = np.array([0.26862954, 0.26130258, 0.27577711])
img_array = (img_array - mean) / std
Contributor:

n00b question: why do we use the CLIP mean/std to normalize the dataset? Is it a common practice?

@lkhphuc (author):

This depends on the pretrained vision encoder. We hardcoded it here because Siglip2 is trained with this normalization.
Ideally it should be part of the model_args because it depends on the model, and we would access it here. However, the current API does not expose model_args to the build_dataloader function. Should we expose it?

Contributor:

Previously we wanted to separate the dataloader and the model because they are unrelated, with the tokenizer as the bridge between the two for text inputs. From the diagram and the code, it seems that for image inputs the VisionTransformer + projector works equivalently as the bridge between the Llama model and the dataloader. As we only support the Siglip2 encoder now, we can hardcode it, and I think we shouldn't expose model_args to build_dataloader for now.

btw, do we need to load pre-trained VisionTransformer weights during pre-training? Or is it trained together with the main Llama model during pre-training?

@lkhphuc (author):

Yes, we will load both pretrained Siglip2 and Llama3. "VLM pretraining" usually refers to starting from a separately pretrained vision encoder and text decoder; only the projector connecting them is randomly initialized.
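
For illustration only, a tiny sketch of the setup described here: pretrained encoder and decoder with a randomly initialized projector in between. The class name and dimensions are assumptions, not the PR's actual modules:

```python
# Only this bridge module starts from random weights; the vision encoder (Siglip2)
# and the language model (Llama3) load their pretrained checkpoints.
import torch.nn as nn

class VisionProjector(nn.Module):
    def __init__(self, vision_dim: int, llm_dim: int):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, visual_tokens):  # [N, L, vision_dim] -> [N, L, llm_dim]
        return self.proj(visual_tokens)
```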



## Design
Distributed training usually does not play nicely with inputs of varying shapes. To handle a varying number of images and image sizes, we require two hyperparameters, image batch size `N` and image length `L` (in patches), and pad the actual image patches to this fixed size.
Contributor:

I think including the diagram in the PR description would make this part much clearer! We should figure out a way to include images in the README.

dp_world_size: int = 1,
infinite: bool = False,
patch_size: int = 16,
merge_size: int = 1,
Contributor:

Can you briefly explain what merge_size does during image processing? I find a lot of functions use merge_size, but it is always set to 1. Is there any case where merge_size is not 1?

This folder showcases how to train a modern Vision Language Model (VLM) in torchtitan.


## Features:
Contributor:

This part mainly describes dataloader features. Can we separate the README into 2 parts: 1) model features (which model/encoder, what is supported now, e.g. FSDP, AC, compile, and TODOs); 2) dataloader features?

return module


def apply_ac(model: nn.Module, ac_config):
Contributor:

You can reuse the apply_ac() function; it is a common building block under torchtitan.distributed.activation_checkpoint.

@lkhphuc (author):

Sorry, I don't see it under torchtitan.distributed.activation_checkpoint?

Contributor:

Sorry, my bad; the change is still in a PR: https://github.com/pytorch/torchtitan/pull/1645/files. For now we import apply_ac via from torchtitan.models.llama3.infra.parallelize import apply_ac.

logger.info(f"Applied {ac_config.mode} activation checkpointing to the model {type(model).__name__}")


def apply_compile(model: nn.Module):
Contributor:

Same for apply_compile and apply_ddp; we could reuse these parts.

@lkhphuc (author) left a comment:

Thanks for your detailed review, Jiani. We have substantially refactored the dataloader part per your suggestions, and made some modifications to make it more robust.

Now we can run both the interleaved Obelics dataset and the CC12M captioning dataset. The CC12M dataset from HF has the images included in the dataset itself, so it should run substantially faster once downloaded.
We still include both here to demonstrate the generality of the approach: we can handle both interleaved and captioning data with no code or model change.

We also added a Sample Packing feature to the dataloader. This is analogous to "document packing" in Llama3 training, and only packs the LLM samples ("<|image|> text text... <|eos|> text <|image|>"). The vision encoder still operates on images of shape NxLxD, where each n in N is a single image.
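
For illustration, a rough sketch of the sample-packing idea on the text side (not the PR's implementation): tokenized samples, already containing image placeholder tokens, are concatenated into fixed-length LLM sequences, while image patches stay in their separate NxLxD batch.

```python
# Sketch: pack tokenized samples into fixed-length sequences, splitting a sample
# across sequence boundaries like Llama3-style document packing.
from typing import Iterable, Iterator

def pack_samples(samples: Iterable[list[int]], seq_len: int, pad_id: int) -> Iterator[list[int]]:
    buffer: list[int] = []
    for tokens in samples:          # each sample already contains <|image|> placeholders
        buffer.extend(tokens)
        while len(buffer) >= seq_len:
            yield buffer[:seq_len]  # one packed LLM sample
            buffer = buffer[seq_len:]
    if buffer:                      # pad the final partial sequence
        yield buffer + [pad_id] * (seq_len - len(buffer))
```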


